{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Document Chunking With LangChain Document Splitters\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/elastic/elasticsearch-labs/blob/main/notebooks/document-chunking/with-langchain-splitters.ipynb)\n",
    "\n",
    "**Using Elasticsearch Nested Dense Vector Support**\n",
    "\n",
    "This interactive notebook will:\n",
    "- load the model \"sentence-transformers__all-minilm-l6-v2\" from Hugging Face and into Elasticsearch ML Node\n",
    "- Use LangChain splitters to chunk the passages into sentences and index them into Elasticsearch with nested dense vector\n",
    "- perform a search and return docs with the most relevant passages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dependencies\n",
    "In this notebook, we're going to use Langchain and the Elasticsearch python client.\n",
    "\n",
    "We will also require a running Elasticsearch instance with an ML node and model deployed to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python3 -m pip install -qU langchain langchain-elasticsearch elasticsearch eland==8.12.1 jq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Connect to Elasticsearch\n",
    "\n",
    "ℹ️ We're using an Elastic Cloud deployment of Elasticsearch for this notebook. If you don't have an Elastic Cloud deployment, sign up [here](https://cloud.elastic.co/registration?utm_source=github&utm_content=elasticsearch-labs-notebook) for a free trial. \n",
    "\n",
    "We'll use the **Cloud ID** to identify our deployment, because we are using Elastic Cloud deployment. To find the Cloud ID for your deployment, go to https://cloud.elastic.co/deployments and select your deployment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from getpass import getpass\n",
    "\n",
    "# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#finding-your-cloud-id\n",
    "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID: \")\n",
    "\n",
    "# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key\n",
    "ELASTIC_API_KEY = getpass(\"Elastic Api Key: \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "client = Elasticsearch(cloud_id=ELASTIC_CLOUD_ID, api_key=ELASTIC_API_KEY)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download our example Dataset\n",
    "We are going to use Langchain's tooling to ingest and split raw documents into smaller chunks. We are using our example workplace search dataset.\n",
    "\n",
    "LangChain has a number of other loaders to ingest data from other sources. See their [core loaders](https://python.langchain.com/docs/modules/data_connection/document_loaders/) or [loaders integration](https://python.langchain.com/docs/integrations/document_loaders) for more information. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from urllib.request import urlopen\n",
    "import json\n",
    "\n",
    "url = \"https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json\"\n",
    "\n",
    "response = urlopen(url)\n",
    "data = json.load(response)\n",
    "\n",
    "with open(\"temp.json\", \"w\") as json_file:\n",
    "    json.dump(data, json_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.document_loaders import JSONLoader\n",
    "\n",
    "\n",
    "def metadata_func(record: dict, metadata: dict) -> dict:\n",
    "    metadata[\"name\"] = record.get(\"name\")\n",
    "    metadata[\"summary\"] = record.get(\"summary\")\n",
    "    metadata[\"url\"] = record.get(\"url\")\n",
    "    metadata[\"category\"] = record.get(\"category\")\n",
    "    metadata[\"updated_at\"] = record.get(\"updated_at\")\n",
    "\n",
    "    return metadata\n",
    "\n",
    "\n",
    "# For more loaders https://python.langchain.com/docs/modules/data_connection/document_loaders/\n",
    "# And 3rd party loaders https://python.langchain.com/docs/modules/data_connection/document_loaders/#third-party-loaders\n",
    "loader = JSONLoader(\n",
    "    file_path=\"temp.json\",\n",
    "    jq_schema=\".[]\",\n",
    "    content_key=\"content\",\n",
    "    metadata_func=metadata_func,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Model from hugging face\n",
    "The first thing you will need is a model to create the text embeddings out of the chunks, you can use whatever you would like, but this example will run end to end on the minilm-l6-v2 model. With an Elastic Cloud cluster created or another Elasticsearch cluster ready, we can upload the text embedding model using the eland library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL_ID = \"sentence-transformers__all-minilm-l6-v2\"\n",
    "\n",
    "!eland_import_hub_model \\\n",
    "    --cloud-id $ELASTIC_CLOUD_ID \\\n",
    "    --es-username elastic \\\n",
    "    --es-api-key $ELASTIC_API_KEY \\\n",
    "    --hub-model-id \"sentence-transformers/all-MiniLM-L6-v2\" \\\n",
    "    --task-type text_embedding \\\n",
    "    --clear-previous \\\n",
    "    --start"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting up our Elasticsearch Index\n",
    "In this example we're going to use a pipeline to do the inference and store the embeddings in our index. \n",
    "\n",
    "In this example, we are using the sentence transformers minilm-l6-v2 model, which you will need to is running on the ML node. With this model, we are setting up an index_pipeline to do the inference and store the embeddings in our index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'nb_parent_retriever_index'})"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "PIPELINE_ID = \"chunk_text_to_passages\"\n",
    "MODEL_DIMS = 384\n",
    "INDEX_NAME = \"nb_parent_retriever_index\"\n",
    "\n",
    "# Create the pipeline\n",
    "client.ingest.put_pipeline(\n",
    "    id=PIPELINE_ID,\n",
    "    processors=[\n",
    "        {\n",
    "            \"foreach\": {\n",
    "                \"field\": \"passages\",\n",
    "                \"processor\": {\n",
    "                    \"inference\": {\n",
    "                        \"field_map\": {\"_ingest._value.text\": \"text_field\"},\n",
    "                        \"model_id\": MODEL_ID,\n",
    "                        \"target_field\": \"_ingest._value.vector\",\n",
    "                        \"on_failure\": [\n",
    "                            {\n",
    "                                \"append\": {\n",
    "                                    \"field\": \"_source._ingest.inference_errors\",\n",
    "                                    \"value\": [\n",
    "                                        {\n",
    "                                            \"message\": \"Processor 'inference' in pipeline 'ml-inference-title-vector' failed with message '{{ _ingest.on_failure_message }}'\",\n",
    "                                            \"pipeline\": \"ml-inference-title-vector\",\n",
    "                                            \"timestamp\": \"{{{ _ingest.timestamp }}}\",\n",
    "                                        }\n",
    "                                    ],\n",
    "                                }\n",
    "                            }\n",
    "                        ],\n",
    "                    }\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Create the index\n",
    "client.indices.create(\n",
    "    index=INDEX_NAME,\n",
    "    settings={\"index\": {\"default_pipeline\": PIPELINE_ID}},\n",
    "    mappings={\n",
    "        \"dynamic\": \"true\",\n",
    "        \"properties\": {\n",
    "            \"passages\": {\n",
    "                \"type\": \"nested\",\n",
    "                \"properties\": {\n",
    "                    \"vector\": {\n",
    "                        \"properties\": {\n",
    "                            \"predicted_value\": {\n",
    "                                \"type\": \"dense_vector\",\n",
    "                                \"index\": True,\n",
    "                                \"dims\": MODEL_DIMS,\n",
    "                                \"similarity\": \"dot_product\",\n",
    "                            }\n",
    "                        }\n",
    "                    }\n",
    "                },\n",
    "            }\n",
    "        },\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utils: Parent Child Splitter Function\n",
    "This function will split a document into multiple passages, and return the parent document with the child passages. \n",
    "\n",
    "It also has an option to chunk the parent document into smaller documents, meaning the parent document will be split into multiple index documents. We will use this in example 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "\n",
    "def parent_child_splitter(documents, chunk_size: int = 200):\n",
    "\n",
    "    child_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size)\n",
    "\n",
    "    docs = []\n",
    "    for i, doc in enumerate(documents):\n",
    "        passages = []\n",
    "\n",
    "        for _doc in child_splitter.split_documents([doc]):\n",
    "            passages.append(\n",
    "                {\n",
    "                    \"text\": _doc.page_content,\n",
    "                }\n",
    "            )\n",
    "\n",
    "        doc = {\n",
    "            \"content\": doc.page_content,\n",
    "            \"metadata\": doc.metadata,\n",
    "            \"passages\": passages,\n",
    "        }\n",
    "        docs.append(doc)\n",
    "\n",
    "    return docs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utils: Pretty Response\n",
    "This function will print out the response from Elasticsearch in an easier to read format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_response(response, show_parent_text=False):\n",
    "    if len(response[\"hits\"][\"hits\"]) == 0:\n",
    "        print(\"Your search returned no results.\")\n",
    "    else:\n",
    "        for hit in response[\"hits\"][\"hits\"]:\n",
    "            id = hit[\"_id\"]\n",
    "            score = hit[\"_score\"]\n",
    "            doc_title = hit[\"_source\"][\"metadata\"][\"name\"]\n",
    "            parent_text = \"\"\n",
    "\n",
    "            if show_parent_text:\n",
    "                parent_text = hit[\"_source\"][\"content\"]\n",
    "\n",
    "            passage_text = \"\"\n",
    "\n",
    "            for passage in hit[\"inner_hits\"][\"passages\"][\"hits\"][\"hits\"]:\n",
    "                passage_text += passage[\"fields\"][\"passages\"][0][\"text\"][0] + \"\\n\\n\"\n",
    "\n",
    "            pretty_output = f\"\\nID: {id}\\nDoc Title: {doc_title}\\nparent text:\\n{parent_text}\\nPassage Text:\\n{passage_text}\\nScore: {score}\\n\"\n",
    "            print(pretty_output)\n",
    "            print(\"---\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Full Document, nested passages\n",
    "In this example we will split a document into passages, and store the full document as a parent document. We will then store the passages as nested documents, with a link back to the parent document.\n",
    "\n",
    "Below we are using the parent child splitter to split the full documents into passages. The `parent_child_splitter` fn returns a list of documents, with an array of nested passages. \n",
    "\n",
    "We then index these documents into Elasticsearch. This will index the full document and the passages will be stored in a nested field. \n",
    "\n",
    "Our index pipeline processor will then run the inference on the passages, and store the embeddings in the index."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Indexed 15 documents with [] errors\n"
     ]
    }
   ],
   "source": [
    "from elasticsearch import helpers\n",
    "\n",
    "chunked_docs = parent_child_splitter(loader.load(), chunk_size=600)\n",
    "\n",
    "count, errors = helpers.bulk(client, chunked_docs, index=INDEX_NAME)\n",
    "\n",
    "print(f\"Indexed {count} documents with {errors} errors\")\n",
    "\n",
    "import time\n",
    "\n",
    "time.sleep(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Perform a Nested Search\n",
    "We can now perform a nested search, to find the passages that match our query, which will be returned in `inner_hits`. In the example that follows only one passage per parent document is requested."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "ID: 1XvjyowBidHK_OJxJozM\n",
      "Doc Title: Work From Home Policy\n",
      "parent text:\n",
      "\n",
      "Passage Text:\n",
      "Effective: March 2020\n",
      "Purpose\n",
      "\n",
      "The purpose of this full-time work-from-home policy is to provide guidelines and support for employees to conduct their work remotely, ensuring the continuity and productivity of business operations during the COVID-19 pandemic and beyond.\n",
      "Scope\n",
      "\n",
      "This policy applies to all employees who are eligible for remote work as determined by their role and responsibilities. It is designed to allow employees to work from home full time while maintaining the same level of performance and collaboration as they would in the office.\n",
      "Eligibility\n",
      "\n",
      "\n",
      "Score: 0.84830964\n",
      "\n",
      "---\n",
      "\n",
      "ID: 3HvjyowBidHK_OJxJozM\n",
      "Doc Title: Intellectual Property Policy\n",
      "parent text:\n",
      "\n",
      "Passage Text:\n",
      "Purpose\n",
      "The purpose of this Intellectual Property Policy is to establish guidelines and procedures for the ownership, protection, and utilization of intellectual property generated by employees during their employment. This policy aims to encourage creativity and innovation while ensuring that the interests of both the company and its employees are protected.\n",
      "\n",
      "Scope\n",
      "This policy applies to all employees, including full-time, part-time, temporary, and contract employees.\n",
      "\n",
      "\n",
      "Score: 0.7292882\n",
      "\n",
      "---\n",
      "\n",
      "ID: 2XvjyowBidHK_OJxJozM\n",
      "Doc Title: Company Vacation Policy\n",
      "parent text:\n",
      "\n",
      "Passage Text:\n",
      "Purpose\n",
      "\n",
      "The purpose of this vacation policy is to outline the guidelines and procedures for requesting and taking time off from work for personal and leisure purposes. This policy aims to promote a healthy work-life balance and encourage employees to take time to rest and recharge.\n",
      "Scope\n",
      "\n",
      "This policy applies to all full-time and part-time employees who have completed their probationary period.\n",
      "Vacation Accrual\n",
      "\n",
      "\n",
      "Score: 0.7137784\n",
      "\n",
      "---\n",
      "\n",
      "ID: 13vjyowBidHK_OJxJozM\n",
      "Doc Title: Wfh Policy Update May 2023\n",
      "parent text:\n",
      "\n",
      "Passage Text:\n",
      "As we continue to prioritize the well-being of our employees, we are making a slight adjustment to our hybrid work policy. Starting May 1, 2023, employees will be required to work from the office three days a week, with two days designated for remote work. Please communicate with your supervisor and HR department to establish your updated in-office workdays.\n",
      "\n",
      "\n",
      "Score: 0.70840263\n",
      "\n",
      "---\n",
      "\n",
      "ID: 43vjyowBidHK_OJxJozM\n",
      "Doc Title: New Employee Onboarding Guide\n",
      "parent text:\n",
      "\n",
      "Passage Text:\n",
      "Review benefits options: Carefully review the benefits package and choose the options that best meet your needs.\n",
      "Complete enrollment forms: Fill out the necessary forms to enroll in your chosen benefits. Submit these forms to the HR department within 30 days of your start date.\n",
      "Designate beneficiaries: If applicable, designate beneficiaries for your life insurance and retirement plans.\n",
      "Getting Settled in Your Workspace\n",
      "To help you feel comfortable and productive in your new workspace, take the following steps:\n",
      "\n",
      "\n",
      "Score: 0.6890813\n",
      "\n",
      "---\n"
     ]
    }
   ],
   "source": [
    "response = client.search(\n",
    "    index=INDEX_NAME,\n",
    "    knn={\n",
    "        \"inner_hits\": {\"size\": 1, \"_source\": False, \"fields\": [\"passages.text\"]},\n",
    "        \"field\": \"passages.vector.predicted_value\",\n",
    "        \"k\": 5,\n",
    "        \"num_candidates\": 100,\n",
    "        \"query_vector_builder\": {\n",
    "            \"text_embedding\": {\n",
    "                \"model_id\": \"sentence-transformers__all-minilm-l6-v2\",\n",
    "                \"model_text\": \"Whats the work from home policy?\",\n",
    "            }\n",
    "        },\n",
    "    },\n",
    ")\n",
    "\n",
    "pretty_response(response)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### With Langchain\n",
    "We can also peform this search within Langchain with an adjustment to the query.\n",
    "\n",
    "We also override the `doc_builder` to populate the `site_content` with the passages rather than the full document."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Doc title: Work From Home Policy\n",
      "Text:\n",
      "Effective: March 2020\n",
      "Purpose\n",
      "\n",
      "The purpose of this full-time work-from-home policy is to provide guidelines and support for employees to conduct their work remotely, ensuring the continuity and productivity of business operations during the COVID-19 pandemic and beyond.\n",
      "Scope\n",
      "\n",
      "This policy applies to all employees who are eligible for remote work as determined by their role and responsibilities. It is designed to allow employees to work from home full time while maintaining the same level of performance and collaboration as they would in the office.\n",
      "Eligibility\n",
      "\n",
      "\n",
      "Doc title: Intellectual Property Policy\n",
      "Text:\n",
      "Purpose\n",
      "The purpose of this Intellectual Property Policy is to establish guidelines and procedures for the ownership, protection, and utilization of intellectual property generated by employees during their employment. This policy aims to encourage creativity and innovation while ensuring that the interests of both the company and its employees are protected.\n",
      "\n",
      "Scope\n",
      "This policy applies to all employees, including full-time, part-time, temporary, and contract employees.\n",
      "\n",
      "\n",
      "Doc title: Company Vacation Policy\n",
      "Text:\n",
      "Purpose\n",
      "\n",
      "The purpose of this vacation policy is to outline the guidelines and procedures for requesting and taking time off from work for personal and leisure purposes. This policy aims to promote a healthy work-life balance and encourage employees to take time to rest and recharge.\n",
      "Scope\n",
      "\n",
      "This policy applies to all full-time and part-time employees who have completed their probationary period.\n",
      "Vacation Accrual\n",
      "\n",
      "\n",
      "Doc title: Wfh Policy Update May 2023\n",
      "Text:\n",
      "As we continue to prioritize the well-being of our employees, we are making a slight adjustment to our hybrid work policy. Starting May 1, 2023, employees will be required to work from the office three days a week, with two days designated for remote work. Please communicate with your supervisor and HR department to establish your updated in-office workdays.\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from langchain_elasticsearch import (\n",
    "    ElasticsearchStore,\n",
    "    ApproxRetrievalStrategy,\n",
    ")\n",
    "from typing import List, Union\n",
    "from langchain_core.documents import Document\n",
    "\n",
    "\n",
    "class CustomRetrievalStrategy(ApproxRetrievalStrategy):\n",
    "\n",
    "    def query(\n",
    "        self,\n",
    "        query: Union[str, None],\n",
    "        filter: List[dict],\n",
    "        **kwargs,\n",
    "    ):\n",
    "\n",
    "        es_query = {\n",
    "            \"knn\": {\n",
    "                \"inner_hits\": {\"_source\": False, \"fields\": [\"passages.text\"]},\n",
    "                \"field\": \"passages.vector.predicted_value\",\n",
    "                \"filter\": filter,\n",
    "                \"k\": 5,\n",
    "                \"num_candidates\": 100,\n",
    "                \"query_vector_builder\": {\n",
    "                    \"text_embedding\": {\n",
    "                        \"model_id\": \"sentence-transformers__all-minilm-l6-v2\",\n",
    "                        \"model_text\": query,\n",
    "                    }\n",
    "                },\n",
    "            }\n",
    "        }\n",
    "\n",
    "        return es_query\n",
    "\n",
    "\n",
    "vector_store = ElasticsearchStore(\n",
    "    index_name=INDEX_NAME,\n",
    "    es_connection=client,\n",
    "    query_field=\"content\",\n",
    "    strategy=CustomRetrievalStrategy(),\n",
    ")\n",
    "\n",
    "\n",
    "def doc_builder(hit):\n",
    "    passage_hits = (\n",
    "        hit.get(\"inner_hits\", {}).get(\"passages\", {}).get(\"hits\", {}).get(\"hits\", [])\n",
    "    )\n",
    "    page_content = \"\"\n",
    "    for passage_hit in passage_hits:\n",
    "        passage_fields = passage_hit.get(\"fields\", {}).get(\"passages\", [])[0]\n",
    "        page_content += passage_fields.get(\"text\", [])[0] + \"\\n\\n\"\n",
    "\n",
    "        return Document(\n",
    "            page_content=page_content,\n",
    "            metadata=hit[\"_source\"][\"metadata\"],\n",
    "        )\n",
    "\n",
    "\n",
    "results = vector_store.similarity_search(\n",
    "    query=\"Whats the work from home policy?\", doc_builder=doc_builder\n",
    ")\n",
    "for result in results:\n",
    "    print(f'Doc title: {result.metadata[\"name\"]}')\n",
    "    print(f\"Text:\\n{result.page_content}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
